/// Source : https://leetcode.com/problems/minimize-malware-spread-ii/
/// Author : liuyubobobo
/// Time   : 2018-10-21

#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;


/// Using Union-Find
/// Time Complexity: O(v^2 * a(v))
/// Space Complexity: O(v)

/// Disjoint-set forest over the elements 0 .. n-1.
/// find() compresses paths and unionElements() merges by size; the two
/// together are what give the near-constant amortized cost per operation
/// that the complexity note above relies on.
class UF{

private:
    std::vector<int> parent; // parent[i] == i  <=>  i is the root of its set
    std::vector<int> sz;     // sz[r] is the set size; only meaningful when r is a root

public:
    /// Creates n singleton sets {0}, {1}, ..., {n-1}.
    /// A non-positive n yields an empty structure.
    UF(int n){
        if(n <= 0)
            return;
        parent.resize(n);
        std::iota(parent.begin(), parent.end(), 0);
        sz.assign(n, 1);
    }

    /// Returns the representative (root) of the set containing p and
    /// re-points every node visited on the way directly at that root.
    int find(int p){
        assert(p >= 0 && p < static_cast<int>(parent.size()));

        // First pass: walk up to the root.
        int root = p;
        while(root != parent[root])
            root = parent[root];

        // Second pass: path compression.
        while(p != root){
            const int next = parent[p];
            parent[p] = root;
            p = next;
        }
        return root;
    }

    /// True when p and q currently belong to the same set.
    bool isConnected(int p , int q){
        return find(p) == find(q);
    }

    /// Merges the sets containing p and q; a no-op if they are already merged.
    void unionElements(int p, int q){

        int pRoot = find(p);
        int qRoot = find(q);

        if( pRoot == qRoot )
            return;

        // Union by size: always hang the smaller tree under the larger root
        // so tree height stays logarithmic. On a tie p's root goes under q's.
        if( sz[pRoot] > sz[qRoot] )
            std::swap(pRoot, qRoot);

        parent[pRoot] = qRoot;
        sz[qRoot] += sz[pRoot];
    }

    /// Number of elements in the set containing p.
    int size(int p){
        return sz[find(p)];
    }
};


/// Solver for "Minimize Malware Spread II" (see the URL in the file header).
class Solution {

public:
    /// Picks the initially infected node whose removal leaves the most
    /// otherwise-clean nodes uninfected.
    ///
    /// @param graph   square adjacency matrix; a non-zero graph[i][j] is
    ///                treated as an edge between i and j
    /// @param initial indices of the initially infected nodes
    /// @return the chosen node; ties are broken towards the smaller index.
    ///         When no single removal saves anything, the smallest index in
    ///         initial is returned. Returns -1 if initial is empty.
    int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial) {

        // Nothing to remove; also keeps min_element below from dereferencing end().
        if(initial.empty())
            return -1;

        const int n = static_cast<int>(graph.size());

        // Take every initially infected node out of the graph.
        std::vector<bool> removed(n, false);
        for(int v: initial)
            removed[v] = true;

        // Build the connected components of the remaining (clean) nodes.
        UF uf(n);
        for(int i = 0; i < n; i ++)
            if(!removed[i])
                for(int j = 0; j < n; j ++)
                    if(!removed[j] && graph[i][j])
                        uf.unionElements(i, j);

        // infection[root]: the initial nodes adjacent to the clean component
        // represented by root.
        std::unordered_map<int, std::unordered_set<int>> infection;
        for(int v: initial)
            for(int i = 0; i < n; i ++)
                if(!removed[i] && graph[v][i])
                    infection[uf.find(i)].insert(v);

        // A component is rescued by removing v only when v is its sole
        // infection source. effective[v] sums the sizes of those components.
        // `const auto&` binds to the map's real element type
        // (pair<const int, ...>) instead of copying each entry.
        std::unordered_map<int, int> effective;
        for(const auto& entry: infection)
            if(entry.second.size() == 1)
                effective[*entry.second.begin()] += uf.size(entry.first);

        // Highest saved count wins; equal counts prefer the smaller index.
        int res = *std::min_element(initial.begin(), initial.end()), best = 0;
        for(const auto& entry: effective)
            if(entry.second > best){
                best = entry.second;
                res = entry.first;
            }
            else if(entry.second == best)
                res = std::min(res, entry.first);
        return res;
    }
};


int main() {

    // Runs one case through a fresh Solution and prints the answer on its own line.
    // Arguments are taken by value so the solver receives modifiable lvalues.
    auto report = [](std::vector<std::vector<int>> adjacency, std::vector<int> infected){
        std::cout << Solution().minMalwareSpread(adjacency, infected) << std::endl;
    };

    // expected: 0
    report({{1,1,0,0},{1,1,1,0},{0,1,1,1},{0,0,1,1}}, {3,0});

    // expected: 0
    report({
            {1,0,0,0,0,0,0,0,0},
            {0,1,0,0,0,0,0,0,0},
            {0,0,1,0,1,0,1,0,0},
            {0,0,0,1,0,0,0,0,0},
            {0,0,1,0,1,0,0,0,0},
            {0,0,0,0,0,1,0,0,0},
            {0,0,1,0,0,0,1,0,0},
            {0,0,0,0,0,0,0,1,0},
            {0,0,0,0,0,0,0,0,1}},
           {6,0,4});

    // expected: 1
    report({{1,1,0},{1,1,1},{0,1,1}}, {0,1});

    // expected: 0
    report({{1,1,0},{1,1,0},{0,0,1}}, {0,1});

    return 0;
}
